{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yfv52r2G33jY"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/AI4Finance-Foundation/FinRL-Tutorials/blob/master/1-Introduction/Stock_NeurIPS2018_SB3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gXaoZs2lh1hi"
   },
   "source": [
    "# Deep Reinforcement Learning for Stock Trading from Scratch: Multiple Stock Trading\n",
    "\n",
    "* **Pytorch Version** \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lGunVt8oLCVS"
   },
   "source": [
    "# Content"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HOzAKQ-SLGX6"
   },
   "source": [
    "* [1. Task Description](#0)\n",
    "* [2. Install Python packages](#1)\n",
    "    * [2.1. Install Packages](#1.1)    \n",
    "    * [2.2. A List of Python Packages](#1.2)\n",
    "    * [2.3. Import Packages](#1.3)\n",
    "    * [2.4. Create Folders](#1.4)\n",
    "* [3. Download and Preprocess Data](#2)\n",
    "* [4. Preprocess Data](#3)        \n",
    "    * [4.1. Technical Indicators](#3.1)\n",
    "    * [4.2. Perform Feature Engineering](#3.2)\n",
    "* [5. Build Market Environment in OpenAI Gym-style](#4)  \n",
    "    * [5.1. Data Split](#4.1)  \n",
    "    * [5.3. Environment for Training](#4.2)    \n",
    "* [6. Train DRL Agents](#5)\n",
    "* [7. Backtesting Performance](#6)  \n",
    "    * [7.1. BackTestStats](#6.1)\n",
    "    * [7.2. BackTestPlot](#6.2)   \n",
    "  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sApkDlD9LIZv"
   },
   "source": [
    "<a id='0'></a>\n",
    "# Part 1. Task Discription"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HjLD2TZSLKZ-"
   },
   "source": [
    "We train a DRL agent for stock trading. This task is modeled as a Markov Decision Process (MDP), and the objective function is maximizing (expected) cumulative return.\n",
    "\n",
    "We specify the state-action-reward as follows:\n",
    "\n",
    "* **State s**: The state space represents an agent's perception of the market environment. Just like a human trader analyzing various information, here our agent passively observes many features and learns by interacting with the market environment (usually by replaying historical data).\n",
    "\n",
    "* **Action a**: The action space includes allowed actions that an agent can take at each state. For example, a ∈ {−1, 0, 1}, where −1, 0, 1 represent\n",
    "selling, holding, and buying. When an action operates multiple shares, a ∈{−k, ..., −1, 0, 1, ..., k}, e.g.. \"Buy\n",
    "10 shares of AAPL\" or \"Sell 10 shares of AAPL\" are 10 or −10, respectively\n",
    "\n",
    "* **Reward function r(s, a, s′)**: Reward is an incentive for an agent to learn a better policy. For example, it can be the change of the portfolio value when taking a at state s and arriving at new state s',  i.e., r(s, a, s′) = v′ − v, where v′ and v represent the portfolio values at state s′ and s, respectively\n",
    "\n",
    "\n",
    "**Market environment**: 30 consituent stocks of Dow Jones Industrial Average (DJIA) index. Accessed at the starting date of the testing period.\n",
    "\n",
    "\n",
    "The data for this case study is obtained from Yahoo Finance API. The data contains Open-High-Low-Close price and volume.\n"
   ]
  },
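  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked illustration of the reward defined above (a sketch with hypothetical numbers, ignoring transaction costs, which a real environment may charge), write the portfolio value as cash plus the market value of the holdings. The symbols $b$ (cash balance), $p_i$ (price of stock $i$), and $h_i$ (shares held of stock $i$) are notation introduced here for illustration only:\n",
    "\n",
    "$$v = b + \\sum_{i} p_i h_i, \\qquad r(s, a, s') = v' - v.$$\n",
    "\n",
    "Suppose the agent holds $b = 1000$ in cash and $h = 10$ shares of a single stock priced at $p = 100$, so $v = 1000 + 10 \\cdot 100 = 2000$. It takes action $a = 5$ (buy 5 shares), leaving $b = 500$ and $h = 15$. If the price then moves to $p' = 102$:\n",
    "\n",
    "$$v' = 500 + 15 \\cdot 102 = 2030, \\qquad r(s, a, s') = 2030 - 2000 = 30.$$\n",
    "\n",
    "Under these assumptions the reward is simply the price change times the number of shares held after the trade, $15 \\cdot (102 - 100) = 30$."
   ]
  },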
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ffsre789LY08"
   },
   "source": [
    "<a id='1'></a>\n",
    "# Part 2. Install Python Packages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Uy5_PTmOh1hj"
   },
   "source": [
    "<a id='1.1'></a>\n",
    "## 2.1. Install packages\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "mPT0ipYE28wL",
    "outputId": "6dad74d2-c37f-4b86-c584-2436d2ef5bae"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: swig in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (4.3.0)\n",
      "Requirement already satisfied: wrds in /home/random/.local/lib/python3.12/site-packages (3.2.0)\n",
      "Requirement already satisfied: numpy<1.27,>=1.26 in /home/random/.local/lib/python3.12/site-packages (from wrds) (1.26.4)\n",
      "Requirement already satisfied: packaging<23.3 in /home/random/.local/lib/python3.12/site-packages (from wrds) (23.2)\n",
      "Requirement already satisfied: pandas<2.3,>=2.2 in /home/random/.local/lib/python3.12/site-packages (from wrds) (2.2.3)\n",
      "Requirement already satisfied: psycopg2-binary<2.10,>=2.9 in /home/random/.local/lib/python3.12/site-packages (from wrds) (2.9.10)\n",
      "Requirement already satisfied: scipy<1.13,>=1.12 in /home/random/.local/lib/python3.12/site-packages (from wrds) (1.12.0)\n",
      "Requirement already satisfied: sqlalchemy<2.1,>=2 in /home/random/.local/lib/python3.12/site-packages (from wrds) (2.0.36)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/random/.local/lib/python3.12/site-packages (from pandas<2.3,>=2.2->wrds) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pandas<2.3,>=2.2->wrds) (2024.1)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /home/random/.local/lib/python3.12/site-packages (from pandas<2.3,>=2.2->wrds) (2024.2)\n",
      "Requirement already satisfied: typing-extensions>=4.6.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from sqlalchemy<2.1,>=2->wrds) (4.12.2)\n",
      "Requirement already satisfied: greenlet!=0.4.17 in /home/random/.local/lib/python3.12/site-packages (from sqlalchemy<2.1,>=2->wrds) (3.1.1)\n",
      "Requirement already satisfied: six>=1.5 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas<2.3,>=2.2->wrds) (1.16.0)\n",
      "Requirement already satisfied: pyportfolioopt in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (1.5.6)\n",
      "Requirement already satisfied: cvxpy>=1.1.19 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pyportfolioopt) (1.6.0)\n",
      "Requirement already satisfied: ecos<3.0.0,>=2.0.14 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pyportfolioopt) (2.0.14)\n",
      "Requirement already satisfied: numpy>=1.26.0 in /home/random/.local/lib/python3.12/site-packages (from pyportfolioopt) (1.26.4)\n",
      "Requirement already satisfied: pandas>=0.19 in /home/random/.local/lib/python3.12/site-packages (from pyportfolioopt) (2.2.3)\n",
      "Requirement already satisfied: plotly<6.0.0,>=5.0.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pyportfolioopt) (5.24.1)\n",
      "Requirement already satisfied: scipy>=1.3 in /home/random/.local/lib/python3.12/site-packages (from pyportfolioopt) (1.12.0)\n",
      "Requirement already satisfied: osqp>=0.6.2 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from cvxpy>=1.1.19->pyportfolioopt) (0.6.7.post3)\n",
      "Requirement already satisfied: clarabel>=0.5.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from cvxpy>=1.1.19->pyportfolioopt) (0.9.0)\n",
      "Requirement already satisfied: scs>=3.2.4.post1 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from cvxpy>=1.1.19->pyportfolioopt) (3.2.7)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/random/.local/lib/python3.12/site-packages (from pandas>=0.19->pyportfolioopt) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pandas>=0.19->pyportfolioopt) (2024.1)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /home/random/.local/lib/python3.12/site-packages (from pandas>=0.19->pyportfolioopt) (2024.2)\n",
      "Requirement already satisfied: tenacity>=6.2.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from plotly<6.0.0,>=5.0.0->pyportfolioopt) (9.0.0)\n",
      "Requirement already satisfied: packaging in /home/random/.local/lib/python3.12/site-packages (from plotly<6.0.0,>=5.0.0->pyportfolioopt) (23.2)\n",
      "Requirement already satisfied: qdldl in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from osqp>=0.6.2->cvxpy>=1.1.19->pyportfolioopt) (0.1.7.post4)\n",
      "Requirement already satisfied: six>=1.5 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas>=0.19->pyportfolioopt) (1.16.0)\n",
      "[sudo] password for random: cmake is already the newest version (3.28.3-1build7).\n",
      "libopenmpi-dev is already the newest version (4.1.6-7ubuntu2).\n",
      "python3-dev is already the newest version (3.12.3-0ubuntu2).\n",
      "zlib1g-dev is already the newest version (1:1.3.dfsg-3.1ubuntu2.1).\n",
      "libgl1-mesa-glx is already the newest version (23.0.4-0ubuntu1~22.04.1).\n",
      "swig is already the newest version (4.2.0-2ubuntu1).\n",
      "0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.\n",
      "Collecting git+https://github.com/AI4Finance-Foundation/FinRL.git\n",
      "  Cloning https://github.com/AI4Finance-Foundation/FinRL.git to /tmp/pip-req-build-flt95p98\n",
      "  Running command git clone --filter=blob:none --quiet https://github.com/AI4Finance-Foundation/FinRL.git /tmp/pip-req-build-flt95p98\n",
      "  Resolved https://github.com/AI4Finance-Foundation/FinRL.git to commit ef471fcea1f3667442f5ecbf7b4c214610a5dd55\n",
      "  Installing build dependencies ... \u001B[?25ldone\n",
      "\u001B[?25h  Getting requirements to build wheel ... \u001B[?25ldone\n",
      "\u001B[?25h  Preparing metadata (pyproject.toml) ... \u001B[?25ldone\n",
      "\u001B[?25hCollecting elegantrl@ git+https://github.com/AI4Finance-Foundation/ElegantRL.git (from finrl==0.3.6)\n",
      "  Cloning https://github.com/AI4Finance-Foundation/ElegantRL.git to /tmp/pip-install-u43l6ss9/elegantrl_36782baa6d82461e89b600dda61820c8\n",
      "  Running command git clone --filter=blob:none --quiet https://github.com/AI4Finance-Foundation/ElegantRL.git /tmp/pip-install-u43l6ss9/elegantrl_36782baa6d82461e89b600dda61820c8\n",
      "  Resolved https://github.com/AI4Finance-Foundation/ElegantRL.git to commit 59d9a33e2b3ba2d77c052c2810bb61059736d88c\n",
      "  Preparing metadata (setup.py) ... \u001B[?25ldone\n",
      "\u001B[?25hRequirement already satisfied: alpaca-trade-api<4,>=3 in /home/random/.local/lib/python3.12/site-packages (from finrl==0.3.6) (3.2.0)\n",
      "Collecting ccxt<4,>=3 (from finrl==0.3.6)\n",
      "  Using cached ccxt-3.1.60-py2.py3-none-any.whl.metadata (108 kB)\n",
      "Requirement already satisfied: exchange-calendars<5,>=4 in /home/random/.local/lib/python3.12/site-packages (from finrl==0.3.6) (4.6)\n",
      "Collecting jqdatasdk<2,>=1 (from finrl==0.3.6)\n",
      "  Using cached jqdatasdk-1.9.7-py3-none-any.whl.metadata (5.8 kB)\n",
      "Collecting pyfolio<0.10,>=0.9 (from finrl==0.3.6)\n",
      "  Using cached pyfolio-0.9.2.tar.gz (91 kB)\n",
      "  Preparing metadata (setup.py) ... \u001B[?25lerror\n",
      "  \u001B[1;31merror\u001B[0m: \u001B[1msubprocess-exited-with-error\u001B[0m\n",
      "  \n",
      "  \u001B[31m×\u001B[0m \u001B[32mpython setup.py egg_info\u001B[0m did not run successfully.\n",
      "  \u001B[31m│\u001B[0m exit code: \u001B[1;36m1\u001B[0m\n",
      "  \u001B[31m╰─>\u001B[0m \u001B[31m[18 lines of output]\u001B[0m\n",
      "  \u001B[31m   \u001B[0m /tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py:468: SyntaxWarning: invalid escape sequence '\\s'\n",
      "  \u001B[31m   \u001B[0m   LONG_VERSION_PY['git'] = '''\n",
      "  \u001B[31m   \u001B[0m Traceback (most recent call last):\n",
      "  \u001B[31m   \u001B[0m   File \"<string>\", line 2, in <module>\n",
      "  \u001B[31m   \u001B[0m   File \"<pip-setuptools-caller>\", line 34, in <module>\n",
      "  \u001B[31m   \u001B[0m   File \"/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/setup.py\", line 71, in <module>\n",
      "  \u001B[31m   \u001B[0m     version=versioneer.get_version(),\n",
      "  \u001B[31m   \u001B[0m             ^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  \u001B[31m   \u001B[0m   File \"/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py\", line 1407, in get_version\n",
      "  \u001B[31m   \u001B[0m     return get_versions()[\"version\"]\n",
      "  \u001B[31m   \u001B[0m            ^^^^^^^^^^^^^^\n",
      "  \u001B[31m   \u001B[0m   File \"/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py\", line 1341, in get_versions\n",
      "  \u001B[31m   \u001B[0m     cfg = get_config_from_root(root)\n",
      "  \u001B[31m   \u001B[0m           ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  \u001B[31m   \u001B[0m   File \"/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py\", line 399, in get_config_from_root\n",
      "  \u001B[31m   \u001B[0m     parser = configparser.SafeConfigParser()\n",
      "  \u001B[31m   \u001B[0m              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  \u001B[31m   \u001B[0m AttributeError: module 'configparser' has no attribute 'SafeConfigParser'. Did you mean: 'RawConfigParser'?\n",
      "  \u001B[31m   \u001B[0m \u001B[31m[end of output]\u001B[0m\n",
      "  \n",
      "  \u001B[1;35mnote\u001B[0m: This error originates from a subprocess, and is likely not a problem with pip.\n",
      "\u001B[1;31merror\u001B[0m: \u001B[1mmetadata-generation-failed\u001B[0m\n",
      "\n",
      "\u001B[31m×\u001B[0m Encountered error while generating package metadata.\n",
      "\u001B[31m╰─>\u001B[0m See above for output.\n",
      "\n",
      "\u001B[1;35mnote\u001B[0m: This is an issue with the package mentioned above, not pip.\n",
      "\u001B[1;36mhint\u001B[0m: See above for details.\n",
      "\u001B[?25h"
     ]
    }
   ],
   "source": [
    "## install required packages\n",
    "\n",
    "!pip install swig\n",
    "!pip install wrds\n",
    "!pip install pyportfolioopt\n",
    "## install finrl library\n",
    "!pip install -q condacolab\n",
    "import condacolab\n",
    "condacolab.install()\n",
    "!apt-get update -y -qq && apt-get install -y -qq cmake libopenmpi-dev python3-dev zlib1g-dev libgl1-mesa-glx swig\n",
    "!pip install git+https://github.com/AI4Finance-Foundation/FinRL.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "osBHhVysOEzi"
   },
   "source": [
    "\n",
    "<a id='1.2'></a>\n",
    "## 2.2. A list of Python packages \n",
    "* Yahoo Finance API\n",
    "* pandas\n",
    "* numpy\n",
    "* matplotlib\n",
    "* stockstats\n",
    "* OpenAI gym\n",
    "* stable-baselines\n",
    "* tensorflow\n",
    "* pyfolio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nGv01K8Sh1hn"
   },
   "source": [
    "<a id='1.3'></a>\n",
    "## 2.3. Import Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "lPqeTTwoh1hn",
    "outputId": "e55033fc-48ae-4696-ae45-08b8bef664d5"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-04 15:29:19.697527: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2025-01-04 15:29:19.724461: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1736000959.745993   24692 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1736000959.755250   24692 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2025-01-04 15:29:19.798332: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "/home/random/anaconda3/envs/finrl/lib/python3.12/site-packages/pyfolio/pos.py:25: UserWarning: Module \"zipline.assets\" not found; multipliers will not be applied to position notionals.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "# matplotlib.use('Agg')\n",
    "%matplotlib inline\n",
    "\n",
    "from finrl.meta.preprocessor.yahoodownloader import YahooDownloader\n",
    "from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split\n",
    "from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv\n",
    "from finrl.agents.stablebaselines3.models import DRLAgent\n",
    "from stable_baselines3.common.logger import configure\n",
    "from finrl.meta.data_processor import DataProcessor\n",
    "from finrl.meta.data_processors.processor_yahoofinance import YahooFinanceProcessor\n",
    "from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline\n",
    "from pprint import pprint\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"../FinRL\")\n",
    "\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T2owTj985RW4"
   },
   "source": [
    "<a id='1.4'></a>\n",
    "## 2.4. Create Folders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "RtUc_ofKmpdy"
   },
   "outputs": [],
   "source": [
    "from finrl import config\n",
    "from finrl import config_tickers\n",
    "import os\n",
    "from finrl.main import check_and_make_directories\n",
    "from finrl.config import (\n",
    "    DATA_SAVE_DIR,\n",
    "    TRAINED_MODEL_DIR,\n",
    "    TENSORBOARD_LOG_DIR,\n",
    "    RESULTS_DIR,\n",
    "    INDICATORS,\n",
    "    TRAIN_START_DATE,\n",
    "    TRAIN_END_DATE,\n",
    "    TEST_START_DATE,\n",
    "    TEST_END_DATE,\n",
    "    TRADE_START_DATE,\n",
    "    TRADE_END_DATE,\n",
    ")\n",
    "check_and_make_directories([DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR])\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "A289rQWMh1hq"
   },
   "source": [
    "<a id='2'></a>\n",
    "# Part 3. Download Data\n",
    "Yahoo Finance provides stock data, financial news, financial reports, etc. Yahoo Finance is free.\n",
    "* FinRL uses a class **YahooDownloader** in FinRL-Meta to fetch data via Yahoo Finance API\n",
    "* Call Limit: Using the Public API (without authentication), you are limited to 2,000 requests per hour per IP (or up to a total of 48,000 requests a day)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NPeQ7iS-LoMm"
   },
   "source": [
    "\n",
    "\n",
    "-----\n",
    "class YahooDownloader:\n",
    "    Retrieving daily stock data from\n",
    "    Yahoo Finance API\n",
    "\n",
    "    Attributes\n",
    "    ----------\n",
    "        start_date : str\n",
    "            start date of the data (modified from config.py)\n",
    "        end_date : str\n",
    "            end date of the data (modified from config.py)\n",
    "        ticker_list : list\n",
    "            a list of stock tickers (modified from config.py)\n",
    "\n",
    "    Methods\n",
    "    -------\n",
    "    fetch_data()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 0
    },
    "id": "h3XJnvrbLp-C",
    "outputId": "a03772b5-9cad-463f-e1d6-58d91a70a594"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2020-07-31'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# from config.py, TRAIN_START_DATE is a string\n",
    "TRAIN_START_DATE\n",
    "# from config.py, TRAIN_END_DATE is a string\n",
    "TRAIN_END_DATE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "FUnY8WEfLq3C"
   },
   "outputs": [],
   "source": [
    "TRAIN_START_DATE = '2010-01-01'\n",
    "TRAIN_END_DATE = '2021-10-01'\n",
    "TRADE_START_DATE = '2021-10-01'\n",
    "TRADE_END_DATE = '2023-03-01'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "yCKm4om-s9kE",
    "outputId": "fd758d58-8946-42ee-e2e3-16f4ac74add2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing AXP (1/3)... 33.33% complete.\n",
      "Processing AMGN (2/3)... 66.67% complete.\n",
      "Processing AAPL (3/3)... 100.00% complete.\n",
      "         Date   Open   High    Low  Close  Adj Close      Volume  tick  day\n",
      "0  2010-01-04   7.62   7.66   7.59   7.64       6.45   493729600  AAPL    3\n",
      "1  2010-01-04  56.63  57.87  56.56  57.72      40.92     5277400  AMGN    3\n",
      "2  2010-01-04  40.81  41.10  40.39  40.92      32.83     6894300   AXP    3\n",
      "3  2010-01-05   7.66   7.70   7.62   7.66       6.46   601904800  AAPL    4\n",
      "4  2010-01-05  57.33  57.69  56.27  57.22      40.56     7882800  AMGN    4\n",
      "5  2010-01-05  40.83  41.23  40.37  40.83      32.76    10641200   AXP    4\n",
      "6  2010-01-06   7.66   7.69   7.53   7.53       6.36   552160000  AAPL    5\n",
      "7  2010-01-06  56.94  57.39  56.50  56.79      40.26     6015100  AMGN    5\n",
      "8  2010-01-06  41.23  41.67  41.17  41.49      33.29     8399400   AXP    5\n",
      "9  2010-01-07   7.56   7.57   7.47   7.52       6.34   477131200  AAPL    6\n",
      "10 2010-01-07  56.41  56.53  54.65  56.27      39.89    10371600  AMGN    6\n",
      "11 2010-01-07  41.26  42.24  41.11  41.98      33.83     8981700   AXP    6\n",
      "12 2010-01-08   7.51   7.57   7.47   7.57       6.39   447610800  AAPL    7\n",
      "13 2010-01-08  56.07  56.83  55.64  56.77      40.24     6576000  AMGN    7\n",
      "14 2010-01-08  41.76  42.48  41.40  41.95      33.80     7907700   AXP    7\n",
      "15 2010-01-11   7.60   7.61   7.44   7.50       6.33   462229600  AAPL   10\n",
      "16 2010-01-11  56.93  57.36  56.62  57.02      40.42     4062700  AMGN   10\n",
      "17 2010-01-11  41.74  41.96  41.25  41.47      33.42     7396000   AXP   10\n",
      "18 2010-01-12   7.47   7.49   7.37   7.42       6.26   594459600  AAPL   11\n",
      "19 2010-01-12  57.14  57.42  54.82  56.03      39.72    11268300  AMGN   11\n",
      "20 2010-01-12  41.27  42.35  41.25  42.02      33.86    12657300   AXP   11\n",
      "21 2010-01-13   7.42   7.53   7.29   7.52       6.35   605892000  AAPL   12\n",
      "22 2010-01-13  56.35  56.75  55.96  56.53      40.07     5056200  AMGN   12\n",
      "23 2010-01-13  41.85  42.24  41.57  42.15      33.96    10137200   AXP   12\n",
      "24 2010-01-14   7.50   7.52   7.47   7.48       6.31   432894000  AAPL   13\n",
      "25 2010-01-14  56.35  56.53  55.91  56.16      39.81     4668900  AMGN   13\n",
      "26 2010-01-14  42.04  42.74  42.02  42.68      34.39     8238400   AXP   13\n",
      "27 2010-01-15   7.53   7.56   7.35   7.35       6.20   594067600  AAPL   14\n",
      "28 2010-01-15  56.03  56.51  55.65  56.25      39.87     7240000  AMGN   14\n",
      "29 2010-01-15  42.52  42.84  42.02  42.39      34.16    13629000   AXP   14\n",
      "30 2010-01-19   7.44   7.69   7.40   7.68       6.48   730007600  AAPL   18\n",
      "31 2010-01-19  56.41  57.75  56.24  57.55      40.80     8570100  AMGN   18\n",
      "32 2010-01-19  42.24  43.05  42.11  42.96      34.62     9533800   AXP   18\n",
      "33 2010-01-20   7.68   7.70   7.48   7.56       6.38   612152800  AAPL   19\n",
      "34 2010-01-20  57.62  57.62  56.41  57.20      40.55     6625700  AMGN   19\n",
      "35 2010-01-20  42.93  43.25  42.26  42.98      34.63    11643000   AXP   19\n",
      "36 2010-01-21   7.57   7.62   7.40   7.43       6.27   608154400  AAPL   20\n",
      "37 2010-01-21  57.43  57.56  56.31  56.63      40.14     5833700  AMGN   20\n",
      "38 2010-01-21  42.99  43.10  41.53  42.16      33.97    16974300   AXP   20\n",
      "39 2010-01-22   7.39   7.41   7.04   7.06       5.96   881767600  AAPL   21\n",
      "40 2010-01-22  56.67  57.30  56.53  56.60      40.12     5967600  AMGN   21\n",
      "41 2010-01-22  41.36  41.49  38.19  38.59      31.09    26170800   AXP   21\n",
      "42 2010-01-25   7.23   7.31   7.15   7.25       6.12  1065699600  AAPL   24\n",
      "43 2010-01-25  56.72  56.79  55.55  55.71      39.49     6719400  AMGN   24\n",
      "44 2010-01-25  39.10  39.29  37.50  37.79      30.45    17587600   AXP   24\n",
      "45 2010-01-26   7.36   7.63   7.24   7.36       6.20  1867110000  AAPL   25\n",
      "46 2010-01-26  56.20  56.87  55.70  56.58      40.11    14880300  AMGN   25\n",
      "47 2010-01-26  37.54  39.23  37.52  38.10      30.70    15709900   AXP   25\n",
      "48 2010-01-27   7.39   7.52   7.13   7.42       6.26  1722568400  AAPL   26\n",
      "49 2010-01-27  56.35  57.88  56.35  57.74      40.93     9695000  AMGN   26\n",
      "50 2010-01-27  37.96  38.84  37.83  38.67      31.16    12908300   AXP   26\n",
      "51 2010-01-28   7.32   7.34   7.10   7.12       6.00  1173502400  AAPL   27\n",
      "52 2010-01-28  57.87  58.78  57.56  58.08      41.17    11638200  AMGN   27\n",
      "53 2010-01-28  38.67  38.67  36.83  37.43      30.16    14148600   AXP   27\n",
      "54 2010-01-29   7.18   7.22   6.79   6.86       5.79  1245952400  AAPL   28\n",
      "55 2010-01-29  58.35  58.93  58.16  58.48      41.45     9465700  AMGN   28\n",
      "56 2010-01-29  37.60  38.77  37.36  37.66      30.35    14219900   AXP   28\n"
     ]
    }
   ],
   "source": [
    "#df = YahooDownloader(start_date = TRAIN_START_DATE,\n",
    "#                     end_date = TRADE_END_DATE,\n",
    "#                     ticker_list = config_tickers.DOW_30_TICKER).fetch_data()\n",
    "yfp = YahooFinanceProcessor()\n",
    "df = yfp.scrap_data(['AXP', 'AMGN', 'AAPL'], '2010-01-01', '2010-02-01')\n",
    "print(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JzqRRTOX6aFu",
    "outputId": "58a21ede-016a-4eaf-db9f-aeb190b3f939"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['AXP', 'AMGN', 'AAPL', 'BA', 'CAT', 'CSCO', 'CVX', 'GS', 'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'KO', 'JPM', 'MCD', 'MMM', 'MRK', 'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'CRM', 'VZ', 'V', 'WBA', 'WMT', 'DIS', 'DOW']\n"
     ]
    }
   ],
   "source": [
    "print(config_tickers.DOW_30_TICKER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "CV3HrZHLh1hy",
    "outputId": "c2cf4956-210b-4811-be12-0c7fd18b923c"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(57, 9)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 0
    },
    "id": "4hYkeaPiICHS",
    "outputId": "6d7a1c0d-15dc-4adc-b776-f1020e173a5c"
   },
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "'date'",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyError\u001B[0m                                  Traceback (most recent call last)",
      "\u001B[0;32m/tmp/ipykernel_24692/1255811168.py\u001B[0m in \u001B[0;36m?\u001B[0;34m()\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mdf\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msort_values\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m'date'\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m'tic'\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0mignore_index\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mhead\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m",
      "\u001B[0;32m~/.local/lib/python3.12/site-packages/pandas/core/frame.py\u001B[0m in \u001B[0;36m?\u001B[0;34m(self, by, axis, ascending, inplace, kind, na_position, ignore_index, key)\u001B[0m\n\u001B[1;32m   7168\u001B[0m                 \u001B[0;34mf\"\u001B[0m\u001B[0;34mLength of ascending (\u001B[0m\u001B[0;34m{\u001B[0m\u001B[0mlen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mascending\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m}\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\"\u001B[0m  \u001B[0;31m# type: ignore[arg-type]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   7169\u001B[0m                 \u001B[0;34mf\"\u001B[0m\u001B[0;34m != length of by (\u001B[0m\u001B[0;34m{\u001B[0m\u001B[0mlen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mby\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m}\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\"\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   7170\u001B[0m             \u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   7171\u001B[0m         \u001B[0;32mif\u001B[0m \u001B[0mlen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mby\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;34m>\u001B[0m \u001B[0;36m1\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m-> 7172\u001B[0;31m             \u001B[0mkeys\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_get_label_or_level_values\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0maxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0maxis\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0mx\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mby\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m   7173\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   7174\u001B[0m             \u001B[0;31m# need to rewrap columns in Series to apply key 
function\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   7175\u001B[0m             \u001B[0;32mif\u001B[0m \u001B[0mkey\u001B[0m \u001B[0;32mis\u001B[0m \u001B[0;32mnot\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/.local/lib/python3.12/site-packages/pandas/core/generic.py\u001B[0m in \u001B[0;36m?\u001B[0;34m(self, key, axis)\u001B[0m\n\u001B[1;32m   1907\u001B[0m             \u001B[0mvalues\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mxs\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mkey\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0maxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mother_axes\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_values\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   1908\u001B[0m         \u001B[0;32melif\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_is_level_reference\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mkey\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0maxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0maxis\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   1909\u001B[0m             \u001B[0mvalues\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0maxes\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0maxis\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mget_level_values\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mkey\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_values\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   1910\u001B[0m         \u001B[0;32melse\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m-> 1911\u001B[0;31m             \u001B[0;32mraise\u001B[0m \u001B[0mKeyError\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mkey\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m   1912\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   1913\u001B[0m         \u001B[0;31m# Check for 
duplicates\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m   1914\u001B[0m         \u001B[0;32mif\u001B[0m \u001B[0mvalues\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mndim\u001B[0m \u001B[0;34m>\u001B[0m \u001B[0;36m1\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;31mKeyError\u001B[0m: 'date'"
     ]
    }
   ],
   "source": [
    "df.sort_values(['date','tic'],ignore_index=True).head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uqC6c40Zh1iH"
   },
   "source": [
    "# Part 4: Preprocess Data\n",
    "We need to check for missing data and perform feature engineering to convert each data point into a state.\n",
    "* **Adding technical indicators**. In practical trading, various kinds of information need to be taken into account, such as historical prices, current holdings, and technical indicators. Here, we demonstrate two widely used technical indicators: MACD (trend-following) and RSI (momentum).\n",
    "* **Adding turbulence index**. Risk aversion reflects whether an investor prefers to protect capital. It also influences one's trading strategy under different levels of market volatility. To control risk in a worst-case scenario, such as the financial crisis of 2007–2008, FinRL employs the turbulence index, which measures extreme fluctuations in asset prices."
   ]
  },
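  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The turbulence index described above can be illustrated with a short, self-contained sketch. It assumes the common definition as the squared Mahalanobis distance of the current return vector from its historical mean. The function name `turbulence_sketch`, the 20-row warm-up, and the synthetic returns are hypothetical and are not taken from FinRL's implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def turbulence_sketch(returns, warmup=20):\n",
    "    # returns: array of shape (T, N), one row of asset returns per day\n",
    "    returns = np.asarray(returns, dtype=float)\n",
    "    T = returns.shape[0]\n",
    "    index = np.zeros(T)\n",
    "    for t in range(warmup, T):\n",
    "        hist = returns[:t]                      # only past data, no look-ahead\n",
    "        mu = hist.mean(axis=0)\n",
    "        cov = np.cov(hist, rowvar=False)\n",
    "        diff = returns[t] - mu\n",
    "        # the pseudo-inverse tolerates a singular covariance matrix\n",
    "        index[t] = diff @ np.linalg.pinv(cov) @ diff\n",
    "    return index\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "rets = rng.normal(0.0, 0.01, size=(60, 3))      # synthetic daily returns\n",
    "rets[50] = [0.08, -0.07, 0.09]                  # inject one extreme day\n",
    "turb = turbulence_sketch(rets)\n",
    "print(int(turb.argmax()))                       # the injected day should stand out\n",
    "```\n",
    "\n",
    "Days before the warm-up are left at zero because too little history is available to estimate a covariance matrix."
   ]
  },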
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "PmKP-1ii3RLS",
    "outputId": "22fecb54-5555-4ec4-cb32-0a54f443e54e"
   },
   "outputs": [],
   "source": [
    "fe = FeatureEngineer(\n",
    "                    use_technical_indicator=True,\n",
    "                    tech_indicator_list = INDICATORS,\n",
    "                    use_vix=True,\n",
    "                    use_turbulence=True,\n",
    "                    user_defined_feature = False)\n",
    "\n",
    "processed = fe.preprocess_data(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Kixon2tR3RLT"
   },
   "outputs": [],
   "source": [
    "list_ticker = processed[\"tic\"].unique().tolist()\n",
    "list_date = list(pd.date_range(processed['date'].min(),processed['date'].max()).astype(str))\n",
    "combination = list(itertools.product(list_date,list_ticker))\n",
    "\n",
    "processed_full = pd.DataFrame(combination,columns=[\"date\",\"tic\"]).merge(processed,on=[\"date\",\"tic\"],how=\"left\")\n",
    "processed_full = processed_full[processed_full['date'].isin(processed['date'])]\n",
    "processed_full = processed_full.sort_values(['date','tic'])\n",
    "\n",
    "processed_full = processed_full.fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 0
    },
    "id": "grvhGJJII3Xn",
    "outputId": "2af27938-0df3-4fea-e86d-7a361e71d2e2"
   },
   "outputs": [],
   "source": [
    "processed_full.sort_values(['date','tic'],ignore_index=True).head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5vdORQ384Qx-"
   },
   "outputs": [],
   "source": [
    "mvo_df = processed_full.sort_values(['date','tic'],ignore_index=True)[['date','tic','close']]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-QsYaY0Dh1iw"
   },
   "source": [
    "<a id='4'></a>\n",
    "# Part 5. Build A Market Environment in OpenAI Gym-style\n",
    "The training process involves observing stock price changes, taking an action, and calculating the reward. By interacting with the market environment, the agent will eventually derive a trading strategy that may maximize (expected) rewards.\n",
    "\n",
    "Our market environment, based on OpenAI Gym, simulates stock markets with historical market data."
   ]
  },
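  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The observe, act, reward loop can be sketched with a toy single-stock environment. This is a hypothetical illustration with a Gym-style `reset`/`step` interface and made-up prices; it does not depend on the `gym` package and is far simpler than the `StockTradingEnv` used in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "class ToyTradingEnv:\n",
    "    # Gym-style interface (reset/step) for a single stock\n",
    "    def __init__(self, prices, initial_cash=1000.0):\n",
    "        self.prices = np.asarray(prices, dtype=float)\n",
    "        self.initial_cash = initial_cash\n",
    "\n",
    "    def reset(self):\n",
    "        self.t = 0\n",
    "        self.cash = self.initial_cash\n",
    "        self.shares = 0\n",
    "        return self._obs()\n",
    "\n",
    "    def _obs(self):\n",
    "        # state: cash balance, current price, shares held\n",
    "        return np.array([self.cash, self.prices[self.t], self.shares])\n",
    "\n",
    "    def _value(self):\n",
    "        return self.cash + self.shares * self.prices[self.t]\n",
    "\n",
    "    def step(self, action):\n",
    "        # action: number of shares to buy (>0) or sell (<0), clipped to what is feasible\n",
    "        price = self.prices[self.t]\n",
    "        action = int(np.clip(action, -self.shares, self.cash // price))\n",
    "        before = self._value()\n",
    "        self.cash -= action * price\n",
    "        self.shares += action\n",
    "        self.t += 1\n",
    "        reward = self._value() - before         # change in portfolio value\n",
    "        done = self.t == len(self.prices) - 1\n",
    "        return self._obs(), reward, done, {}\n",
    "\n",
    "env = ToyTradingEnv([10.0, 11.0, 12.0, 11.5])\n",
    "obs, done, total = env.reset(), False, 0.0\n",
    "while not done:\n",
    "    obs, reward, done, _ = env.step(5)          # naive policy: always try to buy 5 shares\n",
    "    total += reward\n",
    "print(total)\n",
    "```\n",
    "\n",
    "The reward here is the one-step change in portfolio value, so the sum of rewards over an episode equals the total profit or loss."
   ]
  },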
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5TOhcryx44bb"
   },
   "source": [
    "## Data Split\n",
    "We split the data into a training set and a trading (test) set as follows:\n",
    "\n",
    "Training data period: 2009-01-01 to 2020-07-01\n",
    "\n",
    "Trading data period: 2020-07-01 to 2021-10-31\n"
   ]
  },
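  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of a date-based split, assuming a half-open interval `[start, end)` so that a shared boundary date such as 2020-07-01 lands in only one of the two sets. The helper `split_by_date` and the toy dataframe are hypothetical; FinRL's `data_split` may differ in its details.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "def split_by_date(df, start, end, date_col='date'):\n",
    "    # keep rows with start <= date < end (ISO date strings compare correctly as text)\n",
    "    mask = (df[date_col] >= start) & (df[date_col] < end)\n",
    "    out = df[mask].sort_values([date_col, 'tic'], ignore_index=True)\n",
    "    # index rows by trading day so all tickers of one day share an index value\n",
    "    out.index = out[date_col].factorize()[0]\n",
    "    return out\n",
    "\n",
    "toy = pd.DataFrame({\n",
    "    'date': ['2020-06-29', '2020-06-30', '2020-07-01', '2020-07-02'] * 2,\n",
    "    'tic': ['AAA'] * 4 + ['BBB'] * 4,\n",
    "    'close': [1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3],\n",
    "})\n",
    "train_toy = split_by_date(toy, '2020-06-29', '2020-07-01')\n",
    "trade_toy = split_by_date(toy, '2020-07-01', '2020-07-03')\n",
    "print(len(train_toy), len(trade_toy))\n",
    "```\n",
    "\n",
    "With a half-open interval, no row appears in both sets, which avoids leaking trading-period data into training."
   ]
  },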
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "W0qaVGjLtgbI",
    "outputId": "4f16484e-811e-46cd-efee-54c6b309f5a5"
   },
   "outputs": [],
   "source": [
    "train = data_split(processed_full, TRAIN_START_DATE,TRAIN_END_DATE)\n",
    "trade = data_split(processed_full, TRADE_START_DATE,TRADE_END_DATE)\n",
    "train_length = len(train)\n",
    "trade_length = len(trade)\n",
    "print(train_length)\n",
    "print(trade_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 0
    },
    "id": "p52zNCOhTtLR",
    "outputId": "d708401b-129f-495b-e691-7ab8666d6847"
   },
   "outputs": [],
   "source": [
    "train.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 0
    },
    "id": "k9zU9YaTTvFq",
    "outputId": "9080799c-a150-4414-c2de-a68c5e7c3a85"
   },
   "outputs": [],
   "source": [
    "trade.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "zYN573SOHhxG",
    "outputId": "f5dcfc60-af90-4aa0-8849-11848b3ef619"
   },
   "outputs": [],
   "source": [
    "INDICATORS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Q2zqII8rMIqn",
    "outputId": "b6f16ea3-8f52-44c7-ceb1-f58dabe3d1be"
   },
   "outputs": [],
   "source": [
    "stock_dimension = len(train.tic.unique())\n",
    "state_space = 1 + 2*stock_dimension + len(INDICATORS)*stock_dimension\n",
    "print(f\"Stock Dimension: {stock_dimension}, State Space: {state_space}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AWyp84Ltto19"
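   },
   "outputs": [],
   "source": [
    "# use separate list objects so the buy and sell costs can be changed independently\n",
    "buy_cost_list = [0.001] * stock_dimension\n",
    "sell_cost_list = [0.001] * stock_dimension\n",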
    "num_stock_shares = [0] * stock_dimension\n",
    "\n",
    "env_kwargs = {\n",
    "    \"hmax\": 100,\n",
    "    \"initial_amount\": 1000000,\n",
    "    \"num_stock_shares\": num_stock_shares,\n",
    "    \"buy_cost_pct\": buy_cost_list,\n",
    "    \"sell_cost_pct\": sell_cost_list,\n",
    "    \"state_space\": state_space,\n",
    "    \"stock_dim\": stock_dimension,\n",
    "    \"tech_indicator_list\": INDICATORS,\n",
    "    \"action_space\": stock_dimension,\n",
    "    \"reward_scaling\": 1e-4\n",
    "}\n",
    "\n",
    "\n",
    "e_train_gym = StockTradingEnv(df = train, **env_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "64EoqOrQjiVf"
   },
   "source": [
    "## Environment for Training\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xwSvvPjutpqS",
    "outputId": "e8fc8f68-b8c9-47a8-e7d2-a6ed0715d216"
   },
   "outputs": [],
   "source": [
    "env_train, _ = e_train_gym.get_sb_env()\n",
    "print(type(env_train))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HMNR5nHjh1iz"
   },
   "source": [
    "<a id='5'></a>\n",
    "# Part 6: Train DRL Agents\n",
    "* The DRL algorithms are from **Stable Baselines 3**. Users are also encouraged to try **ElegantRL** and **Ray RLlib**.\n",
    "* FinRL includes fine-tuned standard DRL algorithms, such as DQN, DDPG, Multi-Agent DDPG, PPO, SAC, A2C and TD3. Users can also design their own DRL algorithms by adapting these."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "364PsqckttcQ"
   },
   "outputs": [],
   "source": [
    "agent = DRLAgent(env = env_train)\n",
    "\n",
    "if_using_a2c = True\n",
    "if_using_ddpg = True\n",
    "if_using_ppo = True\n",
    "if_using_td3 = True\n",
    "if_using_sac = True\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YDmqOyF9h1iz"
   },
   "source": [
    "### Agent Training: 5 algorithms (A2C, DDPG, PPO, TD3, SAC)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uijiWgkuh1jB"
   },
   "source": [
    "### Agent 1: A2C\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GUCnkn-HIbmj",
    "outputId": "7112ce2a-0f62-4a9c-c8be-4443779b4ba0"
   },
   "outputs": [],
   "source": [
    "agent = DRLAgent(env = env_train)\n",
    "model_a2c = agent.get_model(\"a2c\")\n",
    "\n",
    "if if_using_a2c:\n",
    "  # set up logger\n",
    "  tmp_path = RESULTS_DIR + '/a2c'\n",
    "  new_logger_a2c = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n",
    "  # Set new logger\n",
    "  model_a2c.set_logger(new_logger_a2c)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0GVpkWGqH4-D",
    "outputId": "d00d9ef6-7489-4126-f53f-376612f48466"
   },
   "outputs": [],
   "source": [
    "trained_a2c = agent.train_model(model=model_a2c, \n",
    "                             tb_log_name='a2c',\n",
    "                             total_timesteps=50000) if if_using_a2c else None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MRiOtrywfAo1"
   },
   "source": [
    "### Agent 2: DDPG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "M2YadjfnLwgt",
    "outputId": "8c8b5e98-763c-453c-a280-1b4f3ac13510"
   },
   "outputs": [],
   "source": [
    "agent = DRLAgent(env = env_train)\n",
    "model_ddpg = agent.get_model(\"ddpg\")\n",
    "\n",
    "if if_using_ddpg:\n",
    "  # set up logger\n",
    "  tmp_path = RESULTS_DIR + '/ddpg'\n",
    "  new_logger_ddpg = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n",
    "  # Set new logger\n",
    "  model_ddpg.set_logger(new_logger_ddpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tCDa78rqfO_a",
    "outputId": "35589661-85de-42ca-b9f1-52cde7ded447"
   },
   "outputs": [],
   "source": [
    "trained_ddpg = agent.train_model(model=model_ddpg, \n",
    "                             tb_log_name='ddpg',\n",
    "                             total_timesteps=50000) if if_using_ddpg else None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_gDkU-j-fCmZ"
   },
   "source": [
    "### Agent 3: PPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "y5D5PFUhMzSV",
    "outputId": "2abd06c0-deca-457b-819b-3059c3f17645"
   },
   "outputs": [],
   "source": [
    "agent = DRLAgent(env = env_train)\n",
    "PPO_PARAMS = {\n",
    "    \"n_steps\": 2048,\n",
    "    \"ent_coef\": 0.01,\n",
    "    \"learning_rate\": 0.00025,\n",
    "    \"batch_size\": 128,\n",
    "}\n",
    "model_ppo = agent.get_model(\"ppo\",model_kwargs = PPO_PARAMS)\n",
    "\n",
    "if if_using_ppo:\n",
    "  # set up logger\n",
    "  tmp_path = RESULTS_DIR + '/ppo'\n",
    "  new_logger_ppo = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n",
    "  # Set new logger\n",
    "  model_ppo.set_logger(new_logger_ppo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Gt8eIQKYM4G3",
    "outputId": "26365c9a-f608-4dd4-9695-018b98d1036a"
   },
   "outputs": [],
   "source": [
    "trained_ppo = agent.train_model(model=model_ppo, \n",
    "                             tb_log_name='ppo',\n",
    "                             total_timesteps=50000) if if_using_ppo else None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3Zpv4S0-fDBv"
   },
   "source": [
    "### Agent 4: TD3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JSAHhV4Xc-bh",
    "outputId": "db147b9a-163a-4d03-dd6c-9e89f0e8f421"
   },
   "outputs": [],
   "source": [
    "agent = DRLAgent(env = env_train)\n",
    "TD3_PARAMS = {\"batch_size\": 100, \n",
    "              \"buffer_size\": 1000000, \n",
    "              \"learning_rate\": 0.001}\n",
    "\n",
    "model_td3 = agent.get_model(\"td3\",model_kwargs = TD3_PARAMS)\n",
    "\n",
    "if if_using_td3:\n",
    "  # set up logger\n",
    "  tmp_path = RESULTS_DIR + '/td3'\n",
    "  new_logger_td3 = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n",
    "  # Set new logger\n",
    "  model_td3.set_logger(new_logger_td3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "OSRxNYAxdKpU",
    "outputId": "1d85d74c-54cf-4682-a34b-481a5aafe5d4"
   },
   "outputs": [],
   "source": [
    "trained_td3 = agent.train_model(model=model_td3, \n",
    "                             tb_log_name='td3',\n",
    "                             total_timesteps=50000) if if_using_td3 else None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Dr49PotrfG01"
   },
   "source": [
    "### Agent 5: SAC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xwOhVjqRkCdM",
    "outputId": "9018f9ed-0dff-4b75-c0b2-7566784c52cf"
   },
   "outputs": [],
   "source": [
    "agent = DRLAgent(env = env_train)\n",
    "SAC_PARAMS = {\n",
    "    \"batch_size\": 128,\n",
    "    \"buffer_size\": 100000,\n",
    "    \"learning_rate\": 0.0001,\n",
    "    \"learning_starts\": 100,\n",
    "    \"ent_coef\": \"auto_0.1\",\n",
    "}\n",
    "\n",
    "model_sac = agent.get_model(\"sac\",model_kwargs = SAC_PARAMS)\n",
    "\n",
    "if if_using_sac:\n",
    "  # set up logger\n",
    "  tmp_path = RESULTS_DIR + '/sac'\n",
    "  new_logger_sac = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n",
    "  # Set new logger\n",
    "  model_sac.set_logger(new_logger_sac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "K8RSdKCckJyH",
    "outputId": "bfa91496-f7e6-4d0f-fb77-bc9dd1797e81"
   },
   "outputs": [],
   "source": [
    "trained_sac = agent.train_model(model=model_sac, \n",
    "                             tb_log_name='sac',\n",
    "                             total_timesteps=50000) if if_using_sac else None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f2wZgkQXh1jE"
   },
   "source": [
    "## In-sample Performance\n",
    "\n",
    "Assume that the initial capital is $1,000,000."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bEv5KGC8h1jE"
   },
   "source": [
    "### Set turbulence threshold\n",
    "Set the turbulence threshold to be greater than the maximum of the in-sample turbulence data. If the current turbulence index is greater than the threshold, we assume that the current market is volatile."
   ]
  },
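  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to turn in-sample risk data into a threshold is sketched below. It assumes a pandas Series holding one risk reading per in-sample trading day. The 0.996 quantile mirrors the value inspected in the cells that follow; the variable names and the synthetic data are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "# synthetic stand-in for one risk reading per in-sample trading day\n",
    "insample_risk = pd.Series(rng.gamma(shape=2.0, scale=10.0, size=2500))\n",
    "\n",
    "threshold_max = insample_risk.max()             # never exceeded on in-sample data\n",
    "threshold_q = insample_risk.quantile(0.996)     # exceeded on roughly 0.4% of in-sample days\n",
    "\n",
    "is_volatile = insample_risk > threshold_q\n",
    "print(round(float(threshold_q), 2), int(is_volatile.sum()))\n",
    "```\n",
    "\n",
    "A threshold at the in-sample maximum only reacts to conditions more extreme than anything seen in training, while a high quantile trades some false alarms for earlier reaction."
   ]
  },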
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "efwBi84ch1jE"
   },
   "outputs": [],
   "source": [
    "data_risk_indicator = processed_full[(processed_full.date<TRAIN_END_DATE) & (processed_full.date>=TRAIN_START_DATE)]\n",
    "insample_risk_indicator = data_risk_indicator.drop_duplicates(subset=['date'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "VHZMBpSqh1jG",
    "outputId": "3164bf6e-3b83-4bbf-ecd4-7688c6309e8c"
   },
   "outputs": [],
   "source": [
    "insample_risk_indicator.vix.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BDkszkMloRWT",
    "outputId": "7e36e119-63e2-4379-f110-490836222522"
   },
   "outputs": [],
   "source": [
    "insample_risk_indicator.vix.quantile(0.996)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AL7hs7svnNWT",
    "outputId": "13abfde5-de24-40b7-921e-385dd435b3e8"
   },
   "outputs": [],
   "source": [
    "insample_risk_indicator.turbulence.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "N78hfHckoqJ9",
    "outputId": "b5f650e9-cf0a-4481-b519-b77c8a0b1b2a"
   },
   "outputs": [],
   "source": [
    "insample_risk_indicator.turbulence.quantile(0.996)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "U5mmgQF_h1jQ"
   },
   "source": [
    "### Trading (Out-of-sample Performance)\n",
    "\n",
    "In practice, the model is retrained periodically (e.g., quarterly, monthly, or weekly) to take full advantage of the data, and the parameters are tuned along the way. In this notebook, we use the in-sample data from 2009-01 to 2020-07 to tune the parameters only once, so some alpha decay is expected as the trading period lengthens. \n",
    "\n",
    "Numerous hyperparameters, e.g., the learning rate and the total number of samples to train on, influence the learning process and are usually determined by testing some variations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cIqoV0GSI52v"
   },
   "outputs": [],
   "source": [
    "e_trade_gym = StockTradingEnv(df = trade, turbulence_threshold = 70,risk_indicator_col='vix', **env_kwargs)\n",
    "# env_trade, obs_trade = e_trade_gym.get_sb_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 357
    },
    "id": "W_XNgGsBMeVw",
    "outputId": "13588f5a-daef-4a7b-c116-c737bf61e994"
   },
   "outputs": [],
   "source": [
    "trade.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "lbFchno5j3xs",
    "outputId": "5df880d8-ff14-4104-a2f8-a2d1a417cc1c"
   },
   "outputs": [],
   "source": [
    "trained_model = trained_a2c\n",
    "df_account_value_a2c, df_actions_a2c = DRLAgent.DRL_prediction(\n",
    "    model=trained_model, \n",
    "    environment = e_trade_gym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JbYljWGjj3pH",
    "outputId": "2fb2632a-dd77-40f2-eeff-e4b3385727f2"
   },
   "outputs": [],
   "source": [
    "trained_model = trained_ddpg\n",
    "df_account_value_ddpg, df_actions_ddpg = DRLAgent.DRL_prediction(\n",
    "    model=trained_model, \n",
    "    environment = e_trade_gym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "74jNP2DBj3hb",
    "outputId": "9659e354-3d56-4fe3-b6bb-81777d179c51"
   },
   "outputs": [],
   "source": [
    "trained_model = trained_ppo\n",
    "df_account_value_ppo, df_actions_ppo = DRLAgent.DRL_prediction(\n",
    "    model=trained_model, \n",
    "    environment = e_trade_gym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "S7VyGGJPj3SH",
    "outputId": "a65b52c5-aba0-4e48-b111-481b514fcce2"
   },
   "outputs": [],
   "source": [
    "trained_model = trained_td3\n",
    "df_account_value_td3, df_actions_td3 = DRLAgent.DRL_prediction(\n",
    "    model=trained_model, \n",
    "    environment = e_trade_gym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "eLOnL5eYh1jR",
    "outputId": "3d9bf94b-2bb5-4091-dc7f-bfe2851dc0be"
   },
   "outputs": [],
   "source": [
    "trained_model = trained_sac\n",
    "df_account_value_sac, df_actions_sac = DRLAgent.DRL_prediction(\n",
    "    model=trained_model, \n",
    "    environment = e_trade_gym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ERxw3KqLkcP4",
    "outputId": "219b1298-4a18-41a3-8390-788739158dd7"
   },
   "outputs": [],
   "source": [
    "df_account_value_a2c.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GcE-t08w6DaW"
   },
   "source": [
    "<a id='7'></a>\n",
    "# Part 6.5: Mean Variance Optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GzyHU-RokTaj"
   },
   "source": [
    "Mean-variance optimization (MVO) is a classic strategy in portfolio management. Here, we walk through the whole process and add the result as a baseline for comparison.\n",
    "\n",
    "First, reshape the dataframe into the form needed for the MVO weight calculation."
   ]
  },
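  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reshaping step can be sketched with `DataFrame.pivot`, assuming exactly one row per (date, tic) pair. This is an illustrative alternative on toy data, not the function used in this notebook; the tickers and prices are made up.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "long_df = pd.DataFrame({\n",
    "    'date': ['2021-01-04', '2021-01-04', '2021-01-05', '2021-01-05'],\n",
    "    'tic': ['AAA', 'BBB', 'AAA', 'BBB'],\n",
    "    'close': [10.0, 20.0, 11.0, 19.0],\n",
    "})\n",
    "\n",
    "# one row per date, one column per ticker, values are closing prices\n",
    "wide = long_df.pivot(index='date', columns='tic', values='close')\n",
    "\n",
    "# simple daily percentage returns, the input to the mean and covariance estimates\n",
    "returns_pct = wide.pct_change().dropna() * 100\n",
    "print(wide.shape, returns_pct.shape)\n",
    "```\n",
    "\n",
    "The wide layout puts each asset in its own column, which is the shape that mean and covariance estimators expect."
   ]
  },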
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZaxdYAdRcA67"
   },
   "outputs": [],
   "source": [
    "def process_df_for_mvo(df):\n",
    "  df = df.sort_values(['date','tic'],ignore_index=True)[['date','tic','close']]\n",
    "  fst = df\n",
    "  fst = fst.iloc[0:stock_dimension, :]\n",
    "  tic = fst['tic'].tolist()\n",
    "\n",
    "  mvo = pd.DataFrame()\n",
    "\n",
    "  for k in range(len(tic)):\n",
    "    mvo[tic[k]] = 0\n",
    "\n",
    "  for i in range(df.shape[0]//stock_dimension):\n",
    "    n = df\n",
    "    n = n.iloc[i * stock_dimension:(i+1) * stock_dimension, :]\n",
    "    date = n['date'][i*stock_dimension]\n",
    "    mvo.loc[date] = n['close'].tolist()\n",
    "  \n",
    "  return mvo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tcHDZ7hFkdyL"
   },
   "source": [
    "### Helper functions for mean returns and variance-covariance matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gKjY9bvYcEkb"
   },
   "outputs": [],
   "source": [
    "# The code in this section is partially adapted from Dr. G A Vijayalakshmi Pai's notebook:\n",
    "\n",
    "# https://www.kaggle.com/code/vijipai/lesson-5-mean-variance-optimization-of-portfolios/notebook\n",
    "\n",
    "def StockReturnsComputing(StockPrice, Rows, Columns): \n",
    "  import numpy as np \n",
    "  StockReturn = np.zeros([Rows-1, Columns]) \n",
    "  for j in range(Columns):        # j: Assets \n",
    "    for i in range(Rows-1):     # i: Daily Prices \n",
    "      StockReturn[i,j]=((StockPrice[i+1, j]-StockPrice[i,j])/StockPrice[i,j])* 100 \n",
    "      \n",
    "  return StockReturn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CPnMNonxkj-I"
   },
   "source": [
    "### Calculate the weights for mean-variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wdF2erPNcVd3"
   },
   "outputs": [],
   "source": [
    "train_mvo = data_split(processed_full, TRAIN_START_DATE,TRAIN_END_DATE).reset_index()\n",
    "trade_mvo = data_split(processed_full, TRADE_START_DATE,TRADE_END_DATE).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "9-64xYTOcJ36",
    "outputId": "5cf98bac-c467-4ef1-e98c-2bb858a848c2"
   },
   "outputs": [],
   "source": [
    "StockData = process_df_for_mvo(train_mvo)\n",
    "TradeData = process_df_for_mvo(trade_mvo)\n",
    "\n",
    "TradeData.to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "28q2-ebfcfbu",
    "outputId": "3a51ec82-f586-4462-f5d1-604017ffa1fe"
   },
   "outputs": [],
   "source": [
    "#compute asset returns\n",
    "arStockPrices = np.asarray(StockData)\n",
    "[Rows, Cols]=arStockPrices.shape\n",
    "arReturns = StockReturnsComputing(arStockPrices, Rows, Cols)\n",
    "\n",
    "#compute mean returns and variance covariance matrix of returns\n",
    "meanReturns = np.mean(arReturns, axis = 0)\n",
    "covReturns = np.cov(arReturns, rowvar=False)\n",
    " \n",
    "#set precision for printing results\n",
    "np.set_printoptions(precision=3, suppress = True)\n",
    "\n",
    "#display mean returns and variance-covariance matrix of returns\n",
    "print('Mean returns of assets in k-portfolio 1\\n', meanReturns)\n",
    "print('Variance-Covariance matrix of returns\\n', covReturns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ei3f_NxDkpOx"
   },
   "source": [
    "### Use PyPortfolioOpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "bHc3FC3Hckay",
    "outputId": "6585f4b7-fda4-4d83-c3cc-38c5ed750aea"
   },
   "outputs": [],
   "source": [
    "from pypfopt.efficient_frontier import EfficientFrontier\n",
    "\n",
    "ef_mean = EfficientFrontier(meanReturns, covReturns, weight_bounds=(0, 0.5))\n",
    "raw_weights_mean = ef_mean.max_sharpe()\n",
    "cleaned_weights_mean = ef_mean.clean_weights()\n",
    "# scale each weight by the initial capital; iterate over all assets instead of a hard-coded count\n",
    "mvo_weights = np.array([1000000 * cleaned_weights_mean[i] for i in range(len(cleaned_weights_mean))])\n",
    "mvo_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "iFiwDj29ck9s",
    "outputId": "1e4c7967-c5af-43de-a858-beadfef5116c"
   },
   "outputs": [],
   "source": [
    "LastPrice = np.array([1/p for p in StockData.tail(1).to_numpy()[0]])\n",
    "Initial_Portfolio = np.multiply(mvo_weights, LastPrice)\n",
    "Initial_Portfolio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wbcVsNYfcn2B"
   },
   "outputs": [],
   "source": [
    "Portfolio_Assets = TradeData @ Initial_Portfolio\n",
    "MVO_result = pd.DataFrame(Portfolio_Assets, columns=[\"Mean Var\"])\n",
    "# MVO_result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W6vvNSC6h1jZ"
   },
   "source": [
    "<a id='6'></a>\n",
    "# Part 7: Backtesting Results\n",
    "Backtesting plays a key role in evaluating the performance of a trading strategy. An automated backtesting tool is preferred because it reduces human error. We usually use the Quantopian pyfolio package to backtest our trading strategies. It is easy to use and provides a range of individual plots that together give a comprehensive picture of a trading strategy's performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KeDeGAc9VrEg",
    "outputId": "fe8802d9-e883-48fb-ed8d-36a8236322f7"
   },
   "outputs": [],
   "source": [
    "df_result_a2c = df_account_value_a2c.set_index(df_account_value_a2c.columns[0])\n",
    "df_result_a2c.rename(columns = {'account_value':'a2c'}, inplace = True)\n",
    "df_result_ddpg = df_account_value_ddpg.set_index(df_account_value_ddpg.columns[0])\n",
    "df_result_ddpg.rename(columns = {'account_value':'ddpg'}, inplace = True)\n",
    "df_result_td3 = df_account_value_td3.set_index(df_account_value_td3.columns[0])\n",
    "df_result_td3.rename(columns = {'account_value':'td3'}, inplace = True)\n",
    "df_result_ppo = df_account_value_ppo.set_index(df_account_value_ppo.columns[0])\n",
    "df_result_ppo.rename(columns = {'account_value':'ppo'}, inplace = True)\n",
    "df_result_sac = df_account_value_sac.set_index(df_account_value_sac.columns[0])\n",
    "df_result_sac.rename(columns = {'account_value':'sac'}, inplace = True)\n",
    "df_account_value_a2c.to_csv(\"df_account_value_a2c.csv\")\n",
    "#baseline stats\n",
    "print(\"==============Get Baseline Stats===========\")\n",
    "df_dji_ = get_baseline(\n",
    "        ticker=\"^DJI\", \n",
    "        start = TRADE_START_DATE,\n",
    "        end = TRADE_END_DATE)\n",
    "stats = backtest_stats(df_dji_, value_col_name = 'close')\n",
    "df_dji = pd.DataFrame()\n",
    "df_dji['date'] = df_account_value_a2c['date']\n",
    "df_dji['account_value'] = df_dji_['close'] / df_dji_['close'][0] * env_kwargs[\"initial_amount\"]\n",
    "df_dji.to_csv(\"df_dji.csv\")\n",
    "df_dji = df_dji.set_index(df_dji.columns[0])\n",
    "df_dji.to_csv(\"df_dji+.csv\")\n",
    "\n",
    "result = pd.DataFrame()\n",
    "result = pd.merge(result, df_result_a2c, how='outer', left_index=True, right_index=True)\n",
    "result = pd.merge(result, df_result_ddpg, how='outer', left_index=True, right_index=True)\n",
    "result = pd.merge(result, df_result_td3, how='outer', left_index=True, right_index=True)\n",
    "result = pd.merge(result, df_result_ppo, how='outer', left_index=True, right_index=True)\n",
    "result = pd.merge(result, df_result_sac, how='outer', left_index=True, right_index=True)\n",
    "result = pd.merge(result, MVO_result, how='outer', left_index=True, right_index=True)\n",
    "print(result.head())\n",
    "result = pd.merge(result, df_dji, how='outer', left_index=True, right_index=True)\n",
    "# result.columns = ['a2c', 'ddpg', 'td3', 'ppo', 'sac', 'mean var', 'dji']\n",
    "\n",
    "# print(\"result: \", result)\n",
    "result.to_csv(\"result.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 455
    },
    "id": "WLapAJTri_7B",
    "outputId": "d9625b21-8814-4ec5-bc6e-3a331be40856"
   },
   "outputs": [],
   "source": [
    "df_result_ddpg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 498
    },
    "id": "6xRfrqK4RVfq",
    "outputId": "81bdf0b6-6471-4997-8ea0-a97ec5772d39"
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "plt.rcParams[\"figure.figsize\"] = (15,5)\n",
    "plt.figure();\n",
    "result.plot();"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "Uy5_PTmOh1hj",
    "A289rQWMh1hq",
    "uqC6c40Zh1iH",
    "-QsYaY0Dh1iw",
    "uijiWgkuh1jB",
    "MRiOtrywfAo1",
    "_gDkU-j-fCmZ",
    "3Zpv4S0-fDBv"
   ],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "54cefccbf0f07c9750f12aa115c023dfa5ed4acecf9e7ad3bc9391869be60d0c"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
